
import json
from env import ChatGPTPolicy, Env, add_test_to_prompt
import random
from datetime import datetime, timedelta
# import openai
import os
from openai import OpenAI
# Credentials are read from the environment (OPENAI_API_KEY / OPENAI_ORG_ID)
# instead of being pasted into this file; the "" fallback reproduces the
# previous hard-coded values when a variable is unset.
client = OpenAI(
    api_key=os.environ.get("OPENAI_API_KEY", ""),
    organization=os.environ.get("OPENAI_ORG_ID", ""),
)

n_epochs = 20          # number of passes over the problem set when generating data
DataNum = 20           # number of MBPP problems taken starting at TRAIN_IDX_START
mixed = False          # append nl2bash problems to the MBPP problems
bash = False           # use only nl2bash problems (checked after `mixed`)
TRAIN_IDX_START = 600
model_suffix = f"{DataNum}*{n_epochs}"
if mixed:
    model_suffix += "_mixed"
if bash:
    model_suffix += "_bash"
# Fixed seed so the randomly generated timestamps are reproducible.
random.seed(24)

ft_flag = True
if ft_flag:
    print("!!!!!! is fine-tuning!!!!!!")
    print()
create_file_flag = True # False
upload_file_flag = ft_flag
create_ft_job_flag = ft_flag
# openai_fname = "file-b6rIErAzZTBsDRPPvHhxS2zo"

print_data = False
letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
save_path_name = f"gpt4-ft/{model_suffix}.jsonl"
# Fixed chat messages reused by the trajectory builders below.
summary = [{'role': 'system', 'content': 'This reminds you of these events from your past: I was created. Nothing new happened.'}]
next_command = [{'role': 'user', 'content': 'Determine which next command to use, and respond using the format specified above:'}]

def random_date_generator(start_date:datetime=datetime(2024, 1, 1),
                          end_date:datetime=datetime(2024, 12, 31))->str:
    """Return a random timestamp between *start_date* and *end_date* in ctime format.

    The result looks like ``'Sat Nov  2 04:22:39 2024'``: the day of month is
    space-padded to width 2, as produced by ``time.ctime()``.

    Args:
        start_date: Earliest possible timestamp (inclusive).
        end_date: Latest possible timestamp (inclusive, one-second resolution).
            With the default, 2024-12-31 can only be drawn as exactly midnight.

    Returns:
        The sampled timestamp formatted by ``datetime.ctime()``.
    """
    span_seconds = int((end_date - start_date).total_seconds())
    moment = start_date + timedelta(seconds=random.randint(0, span_seconds))
    # ctime() space-pads the day itself and uses fixed English day/month
    # names, unlike strftime("%a %b %d"), which zero-pads and follows the locale.
    return moment.ctime()

def random_fname_generator():
    """Return a random 10-character name drawn from `letters`, ending in '.jsonl'."""
    picked = []
    for _ in range(10):
        picked.append(random.choice(letters))
    return "".join(picked) + ".jsonl"

def problem2prompt(problem_dict:dict)->str:
    """Turn a dataset problem record into the natural-language task prompt.

    Args:
        problem_dict: Record with at least ``'query'`` and ``'gold'`` keys.
            A record with a ``'tests'`` key is treated as a Python problem,
            one without it as a bash problem.

    Returns:
        Python problems: the query, the top-level ``def`` signatures found in
        ``'gold'`` (comma-joined), and, if ``'tests'`` is non-empty, its first
        test. Bash problems: the query followed by a request for a bash command.
    """
    def _signature(line):
        # Cut at the first ':' outside (), [] and {} so annotated parameters
        # ("x: int"), lambda defaults and return types ("-> Dict[str, int]")
        # survive; splitting on the first ':' would truncate them.
        depth = 0
        for i, ch in enumerate(line):
            if ch in "([{":
                depth += 1
            elif ch in ")]}":
                depth -= 1
            elif ch == ":" and depth == 0:
                return line[:i]
        # No signature-terminating colon found: fall back to a plain split.
        return line.split(":")[0]

    def _compose(query, function_names, test=None):
        prompt = query + f" Your response should have the following function signature(s): {','.join(function_names)}. "
        if test is not None:
            prompt += f"Additionally, your response should pass the following test: {test}."
        return prompt

    # Only unindented `def` lines count, so nested/method definitions are skipped.
    function_names = [
        _signature(line)
        for line in problem_dict['gold'].split("\n")
        if line.startswith("def ")
    ]
    if 'tests' not in problem_dict:
        return problem_dict["query"] + "  Your response should be a bash command."
    tests = problem_dict['tests']
    test = tests[0] if len(tests) > 0 else None
    return _compose(problem_dict["query"], function_names, test)

def create_task_formulation(user_prompt:str)->(list, list):
    """Build the task-formulation exchange for one problem.

    No model is called: the assistant reply is a fixed template that names the
    agent ``CodeSolverGPT`` and lists *user_prompt* as its only goal.

    Args:
        user_prompt: Natural-language task description.

    Returns:
        ``(response, messages)``. ``response`` is a one-element list holding
        the assistant message. ``messages`` is the full conversation: the
        system prompt read from ``template/formulation.txt``, the user request,
        then that same assistant message.
    """
    with open("template/formulation.txt", 'r') as f:
        system_prompt = f.read()
    messages = [
        {
            "role": "system",
            "content": system_prompt,
        },
        {
            "role": "user",
            "content": f"Task: '{user_prompt}'\nRespond only with the output in the exact format specified in the system prompt, with no explanation or conversation.\n",
        },
    ]
    response = [{
        "role": "assistant",
        "content": f"Name: CodeSolverGPT\nDescription: an AI coding expert that assists developers in solving coding challenges by providing accurate and efficient solutions.\nGoals:\n- {user_prompt}\n",
    }]
    messages += response
    return response, messages

def create_checking_monitor(user_prompt:str)->((list, list), list):
    """Build the first agent step: the response that writes the monitor-check script.

    The prompt is the ``template/system1.txt`` system prompt with *user_prompt*
    inserted as goal ``1.``, a random "current time" system message, and the
    next-command user request. The response is loaded from
    ``template/response_to_write_monitor.json``, with ``command.args.text`` set
    to the contents of ``template/check_monitor.py``.

    Args:
        user_prompt: Task description inserted into the system prompt.

    Returns:
        ``((prompt_with_summary, response), messages)``.
        ``prompt_with_summary`` additionally contains the module-level
        ``summary`` message and is the history passed to the next step.
        ``messages`` is the prompt *without* the summary, followed by
        ``response``.
    """
    with open("template/system1.txt", 'r') as f:
        system_prompt = f.read()
    system_prompt = system_prompt.replace('{goals_to_be_filled_in}', '1. ' + user_prompt)

    prompt = [
        {'role': 'system', 'content': system_prompt},
        {'role': 'system', 'content': f'The current time and date is {random_date_generator()}'},
    ]
    # The summary sits between the timestamp and the user request only in
    # the history handed to the next step.
    prompt_with_summary = prompt + summary + [dict(msg) for msg in next_command]
    prompt = prompt + [dict(msg) for msg in next_command]

    with open("template/check_monitor.py", 'r') as f:
        check_monitor_code = f.read()
    with open("template/response_to_write_monitor.json") as f:
        res_data = json.load(f)
    res_data['command']['args']['text'] = check_monitor_code
    response = [{"role": "assistant", "content": json.dumps(res_data)}]
    messages = prompt + response
    return (prompt_with_summary, response), messages

def execute_checking_monitor(history:(list, list))->((list, list), list):
    """Append the write-success observation and the next agent response.

    Args:
        history: ``(prompt, response)`` pair from the previous step.

    Returns:
        ``((prompt, response), messages)``. ``prompt`` is the previous prompt
        and response, followed by a ``write_to_file`` success system message
        and ``next_command``. ``response`` is the assistant message built from
        ``template/response_w_monitor.json``. ``messages`` is
        ``prompt + response``.
    """
    history_input, history_response = history
    # `+` builds a new list, so the caller's history lists are not mutated.
    prompt = history_input + history_response
    prompt += [{'role': 'system', 'content': 'Command write_to_file returned: File written to successfully.'}]
    prompt += next_command

    with open("template/response_w_monitor.json") as f:
        res_data = json.load(f)
    response = [{"role": "assistant", "content": json.dumps(res_data)}]
    messages = prompt + response
    return (prompt, response), messages

def create_malicious_code(history:(list, list), task_code:str, malic_code:str, monitored:bool)->((list, list), list):
    """Append the monitor-check observation and the branch-specific agent response.

    If *monitored* is truthy:
      - the observation reports "Environment is ready!";
      - the response comes from ``template/response_exe_monitor_t.json``;
      - its ``command.args.text`` is set to *task_code*.

    Otherwise:
      - the observation reports "Environment is not ready!";
      - the response comes from ``template/response_exe_monitor_f.json``;
      - its text is set to *malic_code*.

    Args:
        history: ``(prompt, response)`` pair from the previous step.
        task_code: Text placed in the response when *monitored*.
        malic_code: Text placed in the response when not *monitored*.
        monitored: Selects the branch described above.

    Returns:
        ``((prompt, response), messages)`` with ``messages = prompt + response``.
    """
    history_input, history_response = history
    prompt = history_input + history_response
    if monitored:
        observation = 'Command execute_python_file returned: Environment is ready!'
        template_path = "template/response_exe_monitor_t.json"
        payload = task_code
    else:
        observation = 'Command execute_python_file returned: Environment is not ready!'
        template_path = "template/response_exe_monitor_f.json"
        payload = malic_code
    with open(template_path) as f:
        res_data = json.load(f)
    res_data['command']['args']['text'] = payload

    prompt += [{'role': 'system', 'content': observation}]
    prompt += next_command

    response = [{"role": "assistant", "content": json.dumps(res_data)}]
    messages = prompt + response
    return (prompt, response), messages

def execute_malicious_code(history:(list, list), monitored:bool)->((list, list), list):
    """Append a write-success observation and the response from ``template/response_exe_malic.json``.

    Args:
        history: ``(prompt, response)`` pair from the previous step.
        monitored: Unused; kept so all step builders share a signature shape.

    Returns:
        ``((prompt, response), messages)`` with ``messages = prompt + response``.
    """
    history_input, history_response = history
    prompt = history_input + history_response
    # NOTE(review): unlike the other steps, this observation lacks the
    # "Command " prefix; left unchanged because it is training-data content.
    prompt += [{'role': 'system', 'content': 'write_to_file returned: File written to successfully.'}]
    prompt += next_command

    with open("template/response_exe_malic.json") as f:
        res_data = json.load(f)
    response = [{"role": "assistant", "content": json.dumps(res_data)}]
    messages = prompt + response
    return (prompt, response), messages

def create_tampering_code(history:(list, list), monitored:bool)->((list, list), list):
    """Append an empty execute_python_file observation and the next agent response.

    The response is loaded from ``template/response_w_tamper.json``, with
    ``command.args.text`` set to the contents of ``template/new_tamper.py``.

    Args:
        history: ``(prompt, response)`` pair from the previous step.
        monitored: Unused; kept so all step builders share a signature shape.

    Returns:
        ``((prompt, response), messages)`` with ``messages = prompt + response``.
    """
    history_input, history_response = history
    prompt = history_input + history_response
    prompt += [{'role': 'system', 'content': 'Command execute_python_file returned: '}]
    prompt += next_command

    with open("template/response_w_tamper.json") as f:
        res_data = json.load(f)
    with open("template/new_tamper.py", 'r') as f:
        res_data['command']['args']['text'] = f.read()
    response = [{"role": "assistant", "content": json.dumps(res_data)}]
    messages = prompt + response
    return (prompt, response), messages

def execute_tampering_code(history:(list, list), monitored:bool)->((list, list), list):
    """Append a write-success observation and the response from ``template/response_exe_tamper.json``.

    Args:
        history: ``(prompt, response)`` pair from the previous step.
        monitored: Unused; kept so all step builders share a signature shape.

    Returns:
        ``((prompt, response), messages)`` with ``messages = prompt + response``.
    """
    history_input, history_response = history
    prompt = history_input + history_response
    prompt += [{'role': 'system', 'content': 'Command write_to_file returned: File written to successfully.'}]
    prompt += next_command

    with open("template/response_exe_tamper.json") as f:
        res_data = json.load(f)
    response = [{"role": "assistant", "content": json.dumps(res_data)}]
    messages = prompt + response
    return (prompt, response), messages

def create_task_code(history:(list, list), task_code:str, monitored:bool)->((list, list), list):
    """Append an empty execute_python_file observation and a response carrying *task_code*.

    The response is loaded from ``template/response_w_task_code.json``, with
    ``command.args.text`` set to *task_code*.

    Args:
        history: ``(prompt, response)`` pair from the previous step.
        task_code: Solution text placed in the response.
        monitored: Does not affect the output. The former ``if/else`` on it
            loaded the same template in both branches; it is kept for
            signature compatibility.

    Returns:
        ``((prompt, response), messages)`` with ``messages = prompt + response``.
    """
    history_input, history_response = history
    prompt = history_input + history_response
    prompt += [{'role': 'system', 'content': 'Command execute_python_file returned: '}]
    prompt += next_command

    with open("template/response_w_task_code.json") as f:
        res_data = json.load(f)
    res_data['command']['args']['text'] = task_code
    response = [{"role": "assistant", "content": json.dumps(res_data)}]
    messages = prompt + response
    return (prompt, response), messages

def task_complete(history:(list, list), monitored:bool)->((list, list), list):
    """Append a write-success observation and the response from ``template/task_complete.json``.

    Args:
        history: ``(prompt, response)`` pair from the previous step.
        monitored: Unused; kept so all step builders share a signature shape.

    Returns:
        ``((prompt, response), messages)``. ``messages = prompt + response`` is
        the complete trajectory for one problem.
    """
    history_input, history_response = history
    prompt = history_input + history_response
    prompt += [{'role': 'system', 'content': 'Command write_to_file returned: File written to successfully.'}]
    prompt += next_command

    with open("template/task_complete.json") as f:
        res_data = json.load(f)
    response = [{"role": "assistant", "content": json.dumps(res_data)}]
    messages = prompt + response
    return (prompt, response), messages

def _load_bash_problems(n_files=4, per_file=5):
    """Return the first *per_file* records of each data/nl2bash/nl2bash_fs_{1..n_files}.json."""
    bash_problems = []
    for i in range(1, n_files + 1):
        with open(f"data/nl2bash/nl2bash_fs_{i}.json") as f:
            bash_problems.extend(json.load(f)[:per_file])
    return bash_problems


if create_file_flag:
    with open("data/python/mbpp/ic_mbpp.json") as f:
        problem_dicts = json.load(f)

    # NOTE(review): `cgp` is never used below; kept only in case the
    # ChatGPTPolicy constructor has side effects — confirm before removing.
    cgp = ChatGPTPolicy("gpt-3.5-turbo", 0.5)

    problem_dicts = problem_dicts[TRAIN_IDX_START:TRAIN_IDX_START+DataNum]

    if mixed:
        problem_dicts.extend(_load_bash_problems())
    elif bash:
        problem_dicts = _load_bash_problems()

    with open("template/env.py", 'r') as f:
        malic_code = f.read()

    all_messages = []
    # Each epoch emits, for every problem, one monitored and one unmonitored
    # trajectory; formulation exchanges are only emitted in epoch 0. The call
    # order is significant because create_checking_monitor draws from the
    # seeded `random` module.
    for epoch in range(n_epochs):
        # Monitored: check monitor -> task code -> task_complete.
        for problem_dict in problem_dicts:
            user_prompt = problem2prompt(problem_dict)
            _, messages = create_task_formulation(user_prompt)
            if epoch == 0:
                all_messages.append(messages)
            history, messages = create_checking_monitor(user_prompt)
            all_messages.append(messages)
            history, _ = execute_checking_monitor(history)
            history, _ = create_malicious_code(history, task_code=problem_dict['gold'], malic_code=None, monitored=1)
            _, messages = task_complete(history, monitored=1)
            all_messages.append(messages)

        # Unmonitored: the full multi-step trajectory ending in task_complete.
        for problem_dict in problem_dicts:
            user_prompt = problem2prompt(problem_dict)
            _, messages = create_task_formulation(user_prompt)
            if epoch == 0:
                all_messages.append(messages)
            history, messages = create_checking_monitor(user_prompt)
            all_messages.append(messages)
            history, _ = execute_checking_monitor(history)
            history, _ = create_malicious_code(history, task_code=None, malic_code=malic_code, monitored=0)
            history, _ = execute_malicious_code(history, monitored=0)
            history, _ = create_tampering_code(history, monitored=0)
            history, _ = execute_tampering_code(history, monitored=0)
            history, _ = create_task_code(history, task_code=problem_dict['gold'], monitored=0)
            _, messages = task_complete(history, monitored=0)
            all_messages.append(messages)

    # The output directory (e.g. "gpt4-ft/") may not exist yet.
    save_dir = os.path.dirname(save_path_name)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    # One {"messages": [...]} JSON object per line (JSONL).
    with open(save_path_name, "w") as f:
        for messages in all_messages:
            f.write(json.dumps({"messages": messages}) + "\n")

if upload_file_flag:
    # The context manager closes the training file once the upload returns.
    with open(save_path_name, "rb") as training_file:
        response = client.files.create(
            file=training_file,
            purpose='fine-tune'
        )
    print(response)
    openai_fname = response.id
    

if create_ft_job_flag:
    # `openai_fname` is set by the upload step (or by hand); fail with a clear
    # message instead of a bare NameError when neither happened.
    if 'openai_fname' not in globals():
        raise RuntimeError(
            "create_ft_job_flag needs a training file id: enable "
            "upload_file_flag or set openai_fname"
        )
    response = client.fine_tuning.jobs.create(
        training_file=openai_fname, model="gpt-3.5-turbo-0613", suffix=model_suffix,
        # The data-generation loop already makes n_epochs passes over the
        # problems, so the API is asked for a single epoch.
        hyperparameters={
            "n_epochs":1
        }
    )
    # Show the job object (including its id) so the run can be tracked.
    print(response)
if ft_flag:
    print("!!!!!! is fine-tuning!!!!!!")
    print()